import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Linear(20, 256), nn.ReLU(),
    nn.Linear(256, 128), nn.ReLU(),
    nn.Linear(128, 10))

# 添加额外的线性层
extra_layer = nn.Linear(10, 256)

# 将第一个线性层与额外的线性层的权重进行绑定
net[0].weight = extra_layer.weight

# 使用新的输入（维度为20）调用模型
X = torch.rand(2, 10)
print(net(X))